{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#线下0.8905046367886152 \n",
    "#线上0.8726"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 103,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "from datetime import datetime\n",
    "#显示所有列\n",
    "pd.set_option('display.max_columns', None)\n",
    "#显示所有行\n",
    "pd.set_option('display.max_rows', None)\n",
    "\n",
    "train_list = os.listdir('../hy_round1_train_20200102/')\n",
    "test_list = os.listdir('../hy_round1_testA_20200102/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #读取所有文档放入一个文件里\n",
    "# train_list = os.listdir('./hy_round1_train_20200102/')\n",
    "# test_list = os.listdir('./hy_round1_testA_20200102/')\n",
    "\n",
    "# train = pd.DataFrame()\n",
    "# for i in train_list:\n",
    "#     file = pd.read_csv('./hy_round1_train_20200102/'+i)\n",
    "#     train = train.append(file)\n",
    "    \n",
    "# test = pd.DataFrame()\n",
    "# for i in test_list:\n",
    "#     file = pd.read_csv('./hy_round1_train_20200102/'+i)\n",
    "#     test = test.append(file)\n",
    "\n",
    "# train.to_csv('./train.csv',index=False)    \n",
    "# test.to_csv('./test.csv',index=False)\n",
    "\n",
    "# train = pd.read_csv('./train.csv')\n",
    "# test = pd.read_csv('./test.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 特征工程"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_feature_base(demo):\n",
    "    demo.rename(columns={'渔船ID': \"ID\", \"速度\": \"speed\", \"方向\": \"direction\"}, inplace=True)\n",
    "    demo_train = pd.DataFrame()\n",
    "    \n",
    "    #分割time特征得到day, hour, quarter\n",
    "    tmp = pd.DataFrame()\n",
    "    tmp['time'] = pd.to_datetime(demo['time'],format='%m%d %H:%M:%S')\n",
    "    demo[\"month\"] = tmp[\"time\"].dt.month\n",
    "    demo[\"day\"] = tmp[\"time\"].dt.day\n",
    "    demo[\"hour\"] = tmp[\"time\"].dt.hour\n",
    "    del tmp\n",
    "\n",
    "    #按时间排序\n",
    "    demo.sort_values([\"time\"],inplace=True, ascending=False)\n",
    "\n",
    "    #计算作业持续时间\n",
    "    start = demo.iloc[-1]['time']\n",
    "    end = demo.iloc[0]['time']\n",
    "    diff = datetime.strptime(str(end),\"%m%d %H:%M:%S\") - datetime.strptime(str(start),\"%m%d %H:%M:%S\")\n",
    "\n",
    "    #构建时间起始日,小时\n",
    "    demo_train['ID'] = [demo['ID'][0]]\n",
    "    demo_train['start_date'] = int(start[2:4])\n",
    "    demo_train['start_hour'] = int(start[5:7])\n",
    "    demo_train['end_date'] = int(end[2:4])\n",
    "    demo_train['end_hour'] = int(end[5:7])\n",
    "    demo_train['work_days'] = diff.days\n",
    "    demo_train['work_seconds'] = diff.seconds\n",
    "    \n",
    "    #unique, min, max, mean, std, median, mode特征: 方向, 速度, x, y\n",
    "    for s in ['x', 'y', 'speed', 'direction']:\n",
    "        temp = demo.groupby('ID')[s].agg({'nunique_' + s: 'nunique', 'min_' + s: 'min', 'max_' + s: 'max', 'std_' + s: 'std', 'median_' + s: 'median', 'mode_' + s: lambda x: np.mean(pd.Series.mode(x))}).reset_index()\n",
    "        demo_train = pd.merge(demo_train,temp, on='ID',how='left')\n",
    "    \n",
    "    #构建x,y坐标交互特征\n",
    "    demo_train['x_max-min'] = demo_train['max_x'] - demo_train['min_x']\n",
    "    demo_train['y_max-min'] = demo_train['max_y'] - demo_train['min_y']\n",
    "    demo_train['rec_area'] = demo_train['y_max-min'] * demo_train['x_max-min']\n",
    "\n",
    "    return demo_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 106,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# #构建训练集特征\n",
    "# train = pd.DataFrame()\n",
    "# for file in tqdm(train_list):\n",
    "#     demo = pd.read_csv('../hy_round1_train_20200102/' + file)\n",
    "#     demo_train = create_feature_base(demo)\n",
    "#     demo_train['type'] = demo['type']\n",
    "#     train = train.append(demo_train)\n",
    "\n",
    "# train.to_csv('../input/train_demo1.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 107,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #构建测试集\n",
    "# test = pd.DataFrame()\n",
    "# for file in tqdm(test_list):\n",
    "#     demo = pd.read_csv('../hy_round1_testA_20200102/' + file)\n",
    "#     demo_test = create_feature_base(demo)\n",
    "#     test = test.append(demo_test)\n",
    "\n",
    "# test['type'] = '测试'\n",
    "# test.to_csv('../input/test_demo1.csv', index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 108,
   "metadata": {},
   "outputs": [],
   "source": [
    "# #合并数据集合\n",
    "# data = train.append(test).reset_index(drop=True)\n",
    "# data.to_csv('../input/data_demo1.csv', index=False)\n",
    "data = pd.read_csv('../input/data_demo1.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 109,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "type\n",
       "刺网    1018\n",
       "围网    1621\n",
       "拖网    4361\n",
       "测试    2000\n",
       "Name: ID, dtype: int64"
      ]
     },
     "execution_count": 109,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#数据分布\n",
    "data.groupby('type')['ID'].nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 110,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mem. usage decreased to  0.84 Mb (65.0% reduction)\n"
     ]
    }
   ],
   "source": [
    "#降低内存使用\n",
    "def reduce_mem_usage(df, verbose=True):\n",
    "    numerics = ['int16', 'int32', 'int64', 'float16', 'float32', 'float64']\n",
    "    start_mem = df.memory_usage().sum() / 1024**2    \n",
    "    for col in df.columns:\n",
    "        col_type = df[col].dtypes\n",
    "        if col_type in numerics:\n",
    "            c_min = df[col].min()\n",
    "            c_max = df[col].max()\n",
    "            if str(col_type)[:3] == 'int':\n",
    "                if c_min > np.iinfo(np.int8).min and c_max < np.iinfo(np.int8).max:\n",
    "                    df[col] = df[col].astype(np.int8)\n",
    "                elif c_min > np.iinfo(np.int16).min and c_max < np.iinfo(np.int16).max:\n",
    "                    df[col] = df[col].astype(np.int16)\n",
    "                elif c_min > np.iinfo(np.int32).min and c_max < np.iinfo(np.int32).max:\n",
    "                    df[col] = df[col].astype(np.int32)\n",
    "                elif c_min > np.iinfo(np.int64).min and c_max < np.iinfo(np.int64).max:\n",
    "                    df[col] = df[col].astype(np.int64)  \n",
    "            else:\n",
    "                if c_min > np.finfo(np.float16).min and c_max < np.finfo(np.float16).max:\n",
    "                    df[col] = df[col].astype(np.float16)\n",
    "                elif c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:\n",
    "                    df[col] = df[col].astype(np.float32)\n",
    "                else:\n",
    "                    df[col] = df[col].astype(np.float64)    \n",
    "    end_mem = df.memory_usage().sum() / 1024**2\n",
    "    if verbose: print('Mem. usage decreased to {:5.2f} Mb ({:.1f}% reduction)'.format(end_mem, 100 * (start_mem - end_mem) / start_mem))\n",
    "    return df\n",
    "\n",
    "data = reduce_mem_usage(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['ID', 'start_date', 'start_hour', 'end_date', 'end_hour', 'work_days',\n",
       "       'work_seconds', 'nunique_x', 'min_x', 'max_x', 'std_x', 'median_x',\n",
       "       'mode_x', 'nunique_y', 'min_y', 'max_y', 'std_y', 'median_y', 'mode_y',\n",
       "       'nunique_speed', 'min_speed', 'max_speed', 'std_speed', 'median_speed',\n",
       "       'mode_speed', 'nunique_direction', 'min_direction', 'max_direction',\n",
       "       'std_direction', 'median_direction', 'mode_direction', 'x_max-min',\n",
       "       'y_max-min', 'rec_area', 'type'],\n",
       "      dtype='object')"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#所有特征\n",
    "data.columns"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 模型训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 112,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import KFold\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "import lightgbm as lgb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "metadata": {},
   "outputs": [],
   "source": [
    "#分离训练集和测试集\n",
    "train = data[data.type!=\"测试\"]\n",
    "test = data[data.type==\"测试\"]\n",
    "\n",
    "#特征选择,X,y分离\n",
    "train_x = train[[i for i in train.columns if i not in ['ID', 'time', 'type']]]\n",
    "test_x = test[[i for i in test.columns if i not in ['ID', 'time', 'type']]]\n",
    "\n",
    "train_y = train['type']\n",
    "\n",
    "#label和type互相装化\n",
    "label2type = dict(zip(range(0, len(set(train_y))), sorted(list(set(train_y)))))\n",
    "type2label = dict(zip(sorted(list(set(train_y))), range(0, len(set(train_y)))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "cv_pred = []\n",
    "oof = train[['ID']]\n",
    "cms = np.zeros((len(set(train_y)),len(set(train_y))))   #混淆矩阵\n",
    "skf = StratifiedKFold(n_splits=5, random_state=27, shuffle=True)\n",
    "\n",
    "feature_importances = pd.DataFrame()\n",
    "feature_importances['feature'] = train_x.columns\n",
    "\n",
    "for index, (train_index, val_index) in enumerate(skf.split(train_x, train_y)):\n",
    "    \n",
    "    lgb_model = lgb.LGBMClassifier(\n",
    "        boosting_type=\"gbdt\", num_leaves=120, reg_alpha=0, reg_lambda=0.,\n",
    "        max_depth=-1, n_estimators=800, objective='multiclass', class_weight='balanced',\n",
    "        subsample=0.9, colsample_bytree=0.5, subsample_freq=1,\n",
    "        learning_rate=0.03, random_state=2018 + index, n_jobs=10, metric=\"None\", importance_type='gain'\n",
    "    )\n",
    "    \n",
    "    train_x1, val_x1, train_y1, val_y1 = \\\n",
    "    train_x.loc[train_index], train_x.loc[val_index], train_y.loc[train_index], train_y.loc[val_index]\n",
    "\n",
    "    lgb_model.fit(train_x1, train_y1)\n",
    "    \n",
    "    #out of folder预测\n",
    "    oof.loc[val_index] = lgb_model.predict(val_x1).reshape(-1, 1)\n",
    "    \n",
    "    #测试集预测\n",
    "    test_y = lgb_model.predict(test_x)\n",
    "    test_y = pd.Series(test_y).map(type2label)\n",
    "    \n",
    "    # Confusion matrix by folds\n",
    "    cms += confusion_matrix(train_y.loc[val_index], oof.loc[val_index])\n",
    " \n",
    "    #特征重要性\n",
    "    feature_importances['fold_{}'.format(index + 1)] = lgb_model.feature_importances_\n",
    "    \n",
    "    if index == 0:\n",
    "        cv_pred = np.array(test_y).reshape(-1, 1)\n",
    "    else:\n",
    "        cv_pred = np.hstack((cv_pred, np.array(test_y).reshape(-1, 1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 115,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8905046367886152"
      ]
     },
     "execution_count": 115,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#oof F1-score\n",
    "from sklearn.metrics import f1_score\n",
    "f1_score(y_true=train[['type']], y_pred=oof, average='macro')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 特征重要性分析"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>feature</th>\n",
       "      <th>fold_1</th>\n",
       "      <th>fold_2</th>\n",
       "      <th>fold_3</th>\n",
       "      <th>fold_4</th>\n",
       "      <th>fold_5</th>\n",
       "      <th>importance</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>mode_y</td>\n",
       "      <td>20682.852198</td>\n",
       "      <td>19728.443746</td>\n",
       "      <td>21525.336604</td>\n",
       "      <td>14244.821267</td>\n",
       "      <td>24239.394632</td>\n",
       "      <td>100420.848448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>median_y</td>\n",
       "      <td>17187.220501</td>\n",
       "      <td>15457.868834</td>\n",
       "      <td>19345.935073</td>\n",
       "      <td>18988.698626</td>\n",
       "      <td>15185.382910</td>\n",
       "      <td>86165.105945</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>max_y</td>\n",
       "      <td>14128.904336</td>\n",
       "      <td>20433.079812</td>\n",
       "      <td>11560.784785</td>\n",
       "      <td>17087.094964</td>\n",
       "      <td>14545.502632</td>\n",
       "      <td>77755.366528</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>32</th>\n",
       "      <td>rec_area</td>\n",
       "      <td>15076.704301</td>\n",
       "      <td>16941.076005</td>\n",
       "      <td>16017.196457</td>\n",
       "      <td>13737.769574</td>\n",
       "      <td>13252.981691</td>\n",
       "      <td>75025.728028</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>min_y</td>\n",
       "      <td>15784.493651</td>\n",
       "      <td>11868.272302</td>\n",
       "      <td>13447.578895</td>\n",
       "      <td>13493.037586</td>\n",
       "      <td>14046.694014</td>\n",
       "      <td>68640.076448</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>mode_x</td>\n",
       "      <td>9545.258781</td>\n",
       "      <td>11794.820588</td>\n",
       "      <td>10158.369567</td>\n",
       "      <td>10879.565570</td>\n",
       "      <td>9784.579231</td>\n",
       "      <td>52162.593738</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>min_x</td>\n",
       "      <td>9664.097930</td>\n",
       "      <td>8894.327560</td>\n",
       "      <td>10350.630752</td>\n",
       "      <td>10425.898882</td>\n",
       "      <td>7376.927732</td>\n",
       "      <td>46711.882857</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>median_x</td>\n",
       "      <td>8304.285040</td>\n",
       "      <td>7497.652763</td>\n",
       "      <td>7795.518364</td>\n",
       "      <td>9039.518383</td>\n",
       "      <td>8156.122067</td>\n",
       "      <td>40793.096617</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>x_max-min</td>\n",
       "      <td>5449.322565</td>\n",
       "      <td>9367.484138</td>\n",
       "      <td>5502.382237</td>\n",
       "      <td>8614.358753</td>\n",
       "      <td>10735.520256</td>\n",
       "      <td>39669.067950</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>22</th>\n",
       "      <td>median_speed</td>\n",
       "      <td>8055.542794</td>\n",
       "      <td>7475.198664</td>\n",
       "      <td>7395.621818</td>\n",
       "      <td>8811.334456</td>\n",
       "      <td>7020.237288</td>\n",
       "      <td>38757.935021</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>31</th>\n",
       "      <td>y_max-min</td>\n",
       "      <td>8690.936007</td>\n",
       "      <td>5132.424246</td>\n",
       "      <td>11006.883088</td>\n",
       "      <td>6064.450526</td>\n",
       "      <td>7848.842827</td>\n",
       "      <td>38743.536694</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>max_x</td>\n",
       "      <td>6607.532696</td>\n",
       "      <td>5996.615924</td>\n",
       "      <td>6813.647172</td>\n",
       "      <td>7594.343244</td>\n",
       "      <td>7589.249968</td>\n",
       "      <td>34601.389005</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>21</th>\n",
       "      <td>std_speed</td>\n",
       "      <td>6707.523022</td>\n",
       "      <td>5647.230342</td>\n",
       "      <td>6103.021901</td>\n",
       "      <td>6704.242616</td>\n",
       "      <td>6990.691938</td>\n",
       "      <td>32152.709818</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>std_y</td>\n",
       "      <td>4864.492763</td>\n",
       "      <td>6415.892577</td>\n",
       "      <td>5786.596568</td>\n",
       "      <td>6405.760005</td>\n",
       "      <td>4878.838259</td>\n",
       "      <td>28351.580172</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>24</th>\n",
       "      <td>nunique_direction</td>\n",
       "      <td>5597.728505</td>\n",
       "      <td>5167.891043</td>\n",
       "      <td>5273.411296</td>\n",
       "      <td>5589.145646</td>\n",
       "      <td>5026.182418</td>\n",
       "      <td>26654.358907</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>nunique_x</td>\n",
       "      <td>5304.575925</td>\n",
       "      <td>4039.640144</td>\n",
       "      <td>4445.915064</td>\n",
       "      <td>3755.580426</td>\n",
       "      <td>4858.908079</td>\n",
       "      <td>22404.619637</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>std_x</td>\n",
       "      <td>4288.605176</td>\n",
       "      <td>3956.118903</td>\n",
       "      <td>4097.521161</td>\n",
       "      <td>4396.480065</td>\n",
       "      <td>4151.176053</td>\n",
       "      <td>20889.901358</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>27</th>\n",
       "      <td>std_direction</td>\n",
       "      <td>3540.231003</td>\n",
       "      <td>3932.346655</td>\n",
       "      <td>3858.876384</td>\n",
       "      <td>3616.731592</td>\n",
       "      <td>3750.268969</td>\n",
       "      <td>18698.454602</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>28</th>\n",
       "      <td>median_direction</td>\n",
       "      <td>3720.182728</td>\n",
       "      <td>3809.963364</td>\n",
       "      <td>3551.002735</td>\n",
       "      <td>3420.227107</td>\n",
       "      <td>3352.652002</td>\n",
       "      <td>17854.027937</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>work_seconds</td>\n",
       "      <td>3670.210676</td>\n",
       "      <td>3528.908306</td>\n",
       "      <td>2968.041955</td>\n",
       "      <td>3557.140554</td>\n",
       "      <td>3779.412512</td>\n",
       "      <td>17503.714002</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>nunique_speed</td>\n",
       "      <td>3088.501268</td>\n",
       "      <td>2927.000700</td>\n",
       "      <td>3139.982073</td>\n",
       "      <td>3050.365031</td>\n",
       "      <td>3443.305184</td>\n",
       "      <td>15649.154255</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>20</th>\n",
       "      <td>max_speed</td>\n",
       "      <td>2651.263624</td>\n",
       "      <td>2914.617214</td>\n",
       "      <td>3012.534731</td>\n",
       "      <td>3310.553608</td>\n",
       "      <td>2645.591340</td>\n",
       "      <td>14534.560517</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>23</th>\n",
       "      <td>mode_speed</td>\n",
       "      <td>2465.455528</td>\n",
       "      <td>2956.258674</td>\n",
       "      <td>2891.740633</td>\n",
       "      <td>2251.088898</td>\n",
       "      <td>2900.680825</td>\n",
       "      <td>13465.224558</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>nunique_y</td>\n",
       "      <td>2877.383817</td>\n",
       "      <td>2138.794822</td>\n",
       "      <td>2058.472249</td>\n",
       "      <td>2833.310045</td>\n",
       "      <td>2346.332646</td>\n",
       "      <td>12254.293579</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>max_direction</td>\n",
       "      <td>1606.976317</td>\n",
       "      <td>1828.716566</td>\n",
       "      <td>1723.262989</td>\n",
       "      <td>1812.298514</td>\n",
       "      <td>1662.094803</td>\n",
       "      <td>8633.349190</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>end_date</td>\n",
       "      <td>1296.333302</td>\n",
       "      <td>1179.805020</td>\n",
       "      <td>1071.990257</td>\n",
       "      <td>1210.670411</td>\n",
       "      <td>1222.590425</td>\n",
       "      <td>5981.389415</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>start_date</td>\n",
       "      <td>1138.501123</td>\n",
       "      <td>1087.042091</td>\n",
       "      <td>1143.227655</td>\n",
       "      <td>1279.215578</td>\n",
       "      <td>1197.072905</td>\n",
       "      <td>5845.059352</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>29</th>\n",
       "      <td>mode_direction</td>\n",
       "      <td>385.939097</td>\n",
       "      <td>454.109515</td>\n",
       "      <td>377.964062</td>\n",
       "      <td>313.802121</td>\n",
       "      <td>437.030916</td>\n",
       "      <td>1968.845711</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>min_speed</td>\n",
       "      <td>420.096775</td>\n",
       "      <td>307.688778</td>\n",
       "      <td>367.562958</td>\n",
       "      <td>381.427200</td>\n",
       "      <td>329.453323</td>\n",
       "      <td>1806.229033</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>start_hour</td>\n",
       "      <td>365.361140</td>\n",
       "      <td>292.444237</td>\n",
       "      <td>317.462876</td>\n",
       "      <td>309.093489</td>\n",
       "      <td>334.871881</td>\n",
       "      <td>1619.233624</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>end_hour</td>\n",
       "      <td>222.733173</td>\n",
       "      <td>175.529204</td>\n",
       "      <td>297.095434</td>\n",
       "      <td>237.863914</td>\n",
       "      <td>248.280656</td>\n",
       "      <td>1181.502381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>work_days</td>\n",
       "      <td>25.613910</td>\n",
       "      <td>16.322973</td>\n",
       "      <td>23.447697</td>\n",
       "      <td>47.472431</td>\n",
       "      <td>30.963150</td>\n",
       "      <td>143.820162</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25</th>\n",
       "      <td>min_direction</td>\n",
       "      <td>12.635765</td>\n",
       "      <td>18.991657</td>\n",
       "      <td>7.149753</td>\n",
       "      <td>10.403372</td>\n",
       "      <td>14.415788</td>\n",
       "      <td>63.596336</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "              feature        fold_1        fold_2        fold_3        fold_4  \\\n",
       "17             mode_y  20682.852198  19728.443746  21525.336604  14244.821267   \n",
       "16           median_y  17187.220501  15457.868834  19345.935073  18988.698626   \n",
       "14              max_y  14128.904336  20433.079812  11560.784785  17087.094964   \n",
       "32           rec_area  15076.704301  16941.076005  16017.196457  13737.769574   \n",
       "13              min_y  15784.493651  11868.272302  13447.578895  13493.037586   \n",
       "11             mode_x   9545.258781  11794.820588  10158.369567  10879.565570   \n",
       "7               min_x   9664.097930   8894.327560  10350.630752  10425.898882   \n",
       "10           median_x   8304.285040   7497.652763   7795.518364   9039.518383   \n",
       "30          x_max-min   5449.322565   9367.484138   5502.382237   8614.358753   \n",
       "22       median_speed   8055.542794   7475.198664   7395.621818   8811.334456   \n",
       "31          y_max-min   8690.936007   5132.424246  11006.883088   6064.450526   \n",
       "8               max_x   6607.532696   5996.615924   6813.647172   7594.343244   \n",
       "21          std_speed   6707.523022   5647.230342   6103.021901   6704.242616   \n",
       "15              std_y   4864.492763   6415.892577   5786.596568   6405.760005   \n",
       "24  nunique_direction   5597.728505   5167.891043   5273.411296   5589.145646   \n",
       "6           nunique_x   5304.575925   4039.640144   4445.915064   3755.580426   \n",
       "9               std_x   4288.605176   3956.118903   4097.521161   4396.480065   \n",
       "27      std_direction   3540.231003   3932.346655   3858.876384   3616.731592   \n",
       "28   median_direction   3720.182728   3809.963364   3551.002735   3420.227107   \n",
       "5        work_seconds   3670.210676   3528.908306   2968.041955   3557.140554   \n",
       "18      nunique_speed   3088.501268   2927.000700   3139.982073   3050.365031   \n",
       "20          max_speed   2651.263624   2914.617214   3012.534731   3310.553608   \n",
       "23         mode_speed   2465.455528   2956.258674   2891.740633   2251.088898   \n",
       "12          nunique_y   2877.383817   2138.794822   2058.472249   2833.310045   \n",
       "26      max_direction   1606.976317   1828.716566   1723.262989   1812.298514   \n",
       "2            end_date   1296.333302   1179.805020   1071.990257   1210.670411   \n",
       "0          start_date   1138.501123   1087.042091   1143.227655   1279.215578   \n",
       "29     mode_direction    385.939097    454.109515    377.964062    313.802121   \n",
       "19          min_speed    420.096775    307.688778    367.562958    381.427200   \n",
       "1          start_hour    365.361140    292.444237    317.462876    309.093489   \n",
       "3            end_hour    222.733173    175.529204    297.095434    237.863914   \n",
       "4           work_days     25.613910     16.322973     23.447697     47.472431   \n",
       "25      min_direction     12.635765     18.991657      7.149753     10.403372   \n",
       "\n",
       "          fold_5     importance  \n",
       "17  24239.394632  100420.848448  \n",
       "16  15185.382910   86165.105945  \n",
       "14  14545.502632   77755.366528  \n",
       "32  13252.981691   75025.728028  \n",
       "13  14046.694014   68640.076448  \n",
       "11   9784.579231   52162.593738  \n",
       "7    7376.927732   46711.882857  \n",
       "10   8156.122067   40793.096617  \n",
       "30  10735.520256   39669.067950  \n",
       "22   7020.237288   38757.935021  \n",
       "31   7848.842827   38743.536694  \n",
       "8    7589.249968   34601.389005  \n",
       "21   6990.691938   32152.709818  \n",
       "15   4878.838259   28351.580172  \n",
       "24   5026.182418   26654.358907  \n",
       "6    4858.908079   22404.619637  \n",
       "9    4151.176053   20889.901358  \n",
       "27   3750.268969   18698.454602  \n",
       "28   3352.652002   17854.027937  \n",
       "5    3779.412512   17503.714002  \n",
       "18   3443.305184   15649.154255  \n",
       "20   2645.591340   14534.560517  \n",
       "23   2900.680825   13465.224558  \n",
       "12   2346.332646   12254.293579  \n",
       "26   1662.094803    8633.349190  \n",
       "2    1222.590425    5981.389415  \n",
       "0    1197.072905    5845.059352  \n",
       "29    437.030916    1968.845711  \n",
       "19    329.453323    1806.229033  \n",
       "1     334.871881    1619.233624  \n",
       "3     248.280656    1181.502381  \n",
       "4      30.963150     143.820162  \n",
       "25     14.415788      63.596336  "
      ]
     },
     "execution_count": 116,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "feature_importances['importance'] = feature_importances[[i for i in feature_importances.columns if i != 'feature']].apply(lambda x: x.sum(), axis=1)\n",
    "feature_importances.sort_values(by='importance',ascending=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 混淆矩阵"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Normalized confusion matrix\n",
      "0.8565815324165029\t0.0893909626719057\t0.054027504911591355\n",
      "0.03392967304133251\t0.8729179518815546\t0.0931523750771129\n",
      "0.011923870671864251\t0.04563173584040358\t0.9424443934877321\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAARoAAAEdCAYAAADTtqgCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAZ2klEQVR4nO3deXxU9b3/8ddnskzCJIjgBqhlXxUVUVGo7KtCFAOiwlXR8nOphWLVy1VxRSyWXlzqgv1Zl7rhUnAFBQmiKAoquCCoCF7RumCbTCYkJDPf+8cM6KUmBPTLmaHv5+ORB5OZycmHA3nNmXNmTsw5h4iIT6GgBxCRPZ9CIyLeKTQi4p1CIyLeKTQi4p1CIyLeZQc9QH1khSMulN846DEySsgcbZo3JpSlx5KdkYgnMINQVlbQo2ScRDzO+++9+61zbt/tb8uI0ITyG5PdfVLQY2SUgpwEc+88i0hBYdCjZJRYeZRwThYFWm87rbw8ysH7773hx27Tw52IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQ7EB+OJsnry1i/vRiZl81nNycLE7v15Hnpp3C/OnFNGsS4eD9Clk4YxSzrxqOGeRkh7jjtwOCHj0t1NTUcN64MRSfOJDrpkwmWlbGqOGDGDlsENGyMgAmnn8u8Xg84EnTz2cb1tO+RTOGDe7LiGGDKSsr46ShAyga0p+y1Lq7cPy4jFh35pwLdgCzmcDFzrla11ZOo4NcdvdJu3Gq7xX1aEOnXzRh2kPLuHT00Wz8tpyehzbn/P9+cdt9JozoyrLVX/LLLgeyYMUGjupwACs/+YZlq78MZGaAgpwEr915FpGCwsBmAHh6zhOsX/cJF026lMsvmUjjJvvQsnUbAMK5YcJ5YaLRKCcXnxronFvFyqOEc7IoCHi9QTI0U6+Zwl333A/AU3OeZEtVFQC54TB5eWGiZVFOGTU6yDG3KS+PcvD+e69wznXb/rbAt2iccxPrikzQ1n3xT8I5WQA0Kghz8H6FZIWM56adwh/P700oZFRU1RDOzSaSl0PCObq02i/QyKSTDes/pWPnQwDofOhh7LPvvlRVVrK5IkZ+g3yemfMkRSNGBjxl+nrl5RKGDujF7bfOJD8/n8rKSmIVMfLz85n75BOcXDwq6BHrZbeHxsxCZvZnM1tsZs+bWYmZZe/uOerr4y/+yVEdmrLirv+ga9v9iSccudlZDJ38BBVVNQw7tjWzS9ZwRv+OOOD4Lgcyu+RDZpzfm8tOOzro8QPXuk1bXnt1CQBLX1nMd5s2seqdFbz37kqi0Sj9Bg1h6tWXc92UydTU1AQ8bXrZ/4CmvLFyNU89v5DFixbStGkz3nl7Be+tSq67AYOHcs2Vk7nq8svSft0FsUVTBHztnOsFnFDbncxsvJktN7PliS2x3Tfddsb078SCFes58v/dz7w3PiUrZCx593MASlb+D+0PakxprIrxM17gD4++SZvme9PuwMY8ueQjQma0PXDvwGZPBwOHnEjl5s2MGj6I3NwwBzRtxrQZtzJ1+kwWvTifSCRCh46d6dCxM68sfinocdNKOBwmEomQnZ3NoCEn8OHqD/jDzNv4/YybWfjCPCKRCB07HULHToeweNHCoMetUxChaQcsBXDOJWq7k3NulnOum3OuWyg3stuG254ZfBetBODbss044JCW+wBwWKt92fD30m33PX/44dz59Eoiedk450g4R0FeThBjp42srCym3jST2U/NJysrRK++/QGY8/ijDB9RzOaKzYRCIUKhELHy4B5Q0lE0Gt12edlrS2nR
qhUAT8x+hJNOGcnmzT9Yd7H0XndBPGVZA3QHnjGzwPcR7ciji9bwwH8N5fR+HamuSTD2hme5dPTRzJ9ezKbSSm7521sANGyQS/N9Cli9YROxymruu2wIm8oquenRNwP+GwTryy828utfnUUoFKJ49Bk0a34g8XiclxctZOYdf6astJRxZ4zEOce9Dz8R9Lhp5fVXl3DDdVeTG86l+7E96HbUMcTjcUpeWsCfZt1DWWkpY0aPwDnHg7PnBD1unXb7UadUXO4G2gLlQAOgv3Ou1ieZQR51ylTpctQp06TTUadMU9dRp92+RZN6unTO7v6+IhKctH/qIiKZT6EREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxLjvoAeojZI6CnETQY2SUSE6Cilh50GNknIpYOfGcrKDHyEh1/X/LiNC0O6gJL9xzbtBjZJRYrJz+E+4nVm1Bj5JRIjmOudNPJ5QTD3qUjFO5pfZ1lhGhCYVCFBQWBj1GxolVG+VbFJqd1SBSQKRA/99+TtpHIyLeKTQi4p1CIyLeKTQi4p1CIyLeKTQi4p1CIyLeKTQi4p1CIyLeKTQi4p1CIyLeKTQi4p1CIyLeKTQi4p1CIyLeKTQi4p1CIyLeKTQi4p1CIyLeKTQi4p1CIyLeKTQi4p1CIyLeKTQi4p1CIyLeKTQi4p1CIyLe1fq7t83sAcBtfzXgnHP/4XUqEdmj1Boa4IrdNoWI7NFqDY1zbgOAmRnQF2hGcosG4H7/o4nInqI++2hmA72ByUA7YIDPgURkz1Of0OzrnLsS+No5dwWwl+eZRGQPU5/Q1JhZGNhoZlOA5p5nEpE9TH1CM9g5VwWMB94DhvsdKT198P579O/dk0H9enH++HF8tmEDA/sez+iRJ5FIJNiyZQsXnndu0GOmjfxwNk9OLWb+jNOYfe0ITjyuDfNnnMb8Gaex7tELGXZcWwob5PLc9FN5/qbRFDbIBWDWJUMJhWwHS//3UFNTw/njxlA8bCDXT5lMtKyMUUWDGDl8ENGyMgAmXnAu8Xg84El3zJzb/gj2dncw+5dD2c65n21nsJnNBC52ztW6tg47/Aj3yusrfq5vuUuqq6vJyckB4Pzx42jXrgPH9ejJkiWL6dd/IMvffIMuhx3OMd2PDXTOrcqjUbqceQflW4L5oS3q2Y5OLfZh2l+Xcunpx/LB+m94ZunHALx861iGXPIIA45qSTgnC4Cq6jiVW2po2CDM7EWrA5kZoCDXsXTWr4gUFAY2w1ZPz3mC9es+4aJJl3LFpRNp3GQfWrZqA0BuOEw4HKY8GuWk4lMDnjQpVh6l3UH7rHDOddv+tvps0VjqIwQcBvT7OYdzzk2sKzLpYmtkAMLhMAmXoLKqkopYjFAoxLur3kmbyKSDdV/8Y1tEGhWE+a5sMwAtmu7FV/+IEauspqKyhnBuNg3ycqmorGHE8R14rCS4yKSbz9Z/SsfOhwDQ+dDDaLLPvlRVVVJRESM/P59n5j7J8BEjA56yfnYYGufcfamPe51zFwNNdvQ1ZnaTmR1iZgPN7O3UdfeZ2TFmVmJmr5rZ2anrS8ysrtfzpI3nnnmKY47swjdff82ZZ53Dww8+AGYseXkxxaNGc8mkCUy/cWrQY6aFjzf+g6M6NmPFn8+ha7sDeO39jQCc1LM9T736EQAvvbWeru0OoEvr/WgYyWXesk+4/tzeTP1Vb7L09InWbdvy+qtLAFi6ZDHffbeJlW+v4P13VxKNRuk3cAg3XH0510+ZTE1NTcDT1m2HoTGz68zs2tTH3fX5GmAp0AM4DvjSzAqB/YFrSe7j6QmcYWa5dXzf8Wa23MyWb9q0qT5/F++Gnjic
ZStW0bRZc15b+gp33v0XJv3uMj75eC0frV3DSSNOIZFI8NHaNUGPGrgxAw9hwfJPOfLc/8+8Zes4rX9nAIZ2b82zryVDUxNPMPGWF5l024sMPKoV5ZureX/9N7y//hv6dG0R4PTpYcDgE6ms3MyookHkhsM0bdqMaTNu5frpM1m0YD6RSIT2nTrTvlNnXln8UtDj1qk+0VgALEz9eZtzbmg9vuZVkpFpDTwIFAFfkXzq9RSwCDgA2Le2BTjnZjnnujnnujVpssONKO+qqqq2XW7YsCF5+fkAzLrjT4w/70IqYhWELEQoFKK8vDyoMdOGmfFdWSUA35ZWsFckzP57R9hSk9h2/Van9unE4yUf0iAvm0TCkUg4CvJzfmyx/1aysrK4fvpMZs+dT1ZWiOP79AdgzuOPMvzkYjZv3pz8P2chYrFYwNPWrT5PWa5wzm17kZ6ZPeycO62uL3DOfW1mTYHPSUZnDnAHsB9Q7JyLmVmOc646+cLj9LfghXncdstMAFq3aUO//gMpLS1l48bP6dipMw0iEc45cwyNmzTm4kv+M+Bpg/fowg944MoiTh/QmeqaBGOvn8uI49vzzNKP/s/9QiGj75EtGH/TczSM5DL7mhGYGcVXPhHQ5Onjyy82ctH4s7BQiOLRZ9Cs+YHE43FeLlnIzNv/TFlpKePGjATn+MtD6b2+aj3qZGZ9SL71YCzfv+UgG+jhnOu1wwWb3Qe845z7bzPbCAwG8oFpJLekvnPOnWJmJUB/51ytTzLT4ahTpgn6qFOmSqejTpmmrqNOdW3RrAMSQCuST50AqoEb6/NNnXNn/uDyD1/k12+7+/Wuz/JEJHPt6E2VG8xsP+fcYtj2Bsti4LHdNJ+I7AHqszP4vK0XXPJ51nl13FdE5F/UJzRhM9sbwMwaA3l+RxKRPU19jjpdBsw1s617jXVIRUR2Sn22aN4guQN4NfBPYIjXiURkj1PXOYNHAcOABsB8oK1z7md9n5OI/Huoa4tmKrAFuME5NwuoquO+IiK1qjU0zrm2wM3AMDN7CuhgZn1SJ8ESEam3OncGO+dWAasAzKwVcAowBejjfzQR2VPU+xfIOefWOeducs4pMiKyU/SbKkXEO4VGRLxTaETEO4VGRLxTaETEO4VGRLxTaETEO4VGRLxTaETEO4VGRLxTaETEO4VGRLxTaETEO4VGRLxTaETEO4VGRLxTaETEO4VGRLxTaETEO4VGRLxTaETEO4VGRLxTaETEO4VGRLyr8zdVpotEIkF5NBr0GBklFisnkuOCHiPjRHIcFbHyoMfISHWtt4wIjQHZ2dr42hnZWSHmzxxLJFIQ9CgZJRYrZ9DZNxGrUqR3ViRstd6WEaEJZWVRWFgY9BgZJ7s6ToHW206LVTnKKxNBj5GBat8Y0GaCiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwqNiHin0IiIdwrNTvjiiy849qiuNCrIo6amhg0bNtC3V09GjigikUiwZcsWzvvVOUGPmVY+27Cedi2aceKgvowYNpiysjKKhg5g+JD+lJWVAXDB+HHE4/GAJ00PWVkh7r/xbObN+g1TJxRtu/6kfofz0fPXAVAYyeO5Oy/i+bsuojCSB8Csa8YQClkgM9dHto+FmlkLoJVz7iUfyw9K48aNee6FhZxafDIAf3vycaZOm86Sl0t45+23efONZZw17tyAp0w/vfv2Z9Y99wPw1JwnGXPmOABKXlpAOBymb/+BZGVlBTli2ijqcxir1m7kD/e8wB8vG8mh7Zrz7tqNnNz/cD7/6h8A9OvegfvmvLbtcmVVNQteW00i4YIcvU6+tmhaAH09LTsweXl57L333ts+b5DfgKqq
SmKxGKFQiFUr36H7sccGOGF6emVxCUP69+L2W2eSn59PVWUlFbEY+fn5zP3bE4woHhX0iGmj5YFNeG/tRgBWrvmcY7q0ZHDPzix8/cNtIamo3EI4nE2D/FwqKrcwYsARPDb/rSDH3iFfoRkPjDWzuJldD2BmZ5nZWanLt5jZy2b2jJnt5WkG70aNPo0HH7gfM+PlxSWMGn0akyb+hhtvuD7o0dLG/gc05c1Vq3l63kJKXlrIAU2b8fZbK3h31Uqi0SgDBw/l6isnM+Xyy6ipqQl63MCtXf81vzyyDQC9urWlUcMGjBl2DA8/++a2+7y07EO6djqYLu0PpGEkj3lL3uf63wxn6oQisrLSc2+Ir6lmAQ8A/ba/wcyOAiLOueOBR4DzPM3gXaNGjbj7nnv53aX/yccfrWXtmjWMOKWYRCLB2jVrgh4vLYTDYSKRCNnZ2QwacgIfrv6AGTffxvQ/3syCF+YRiUTo1OkQOnU6hMWLFgY9buCeffld8vNyee7Oi6iqruHrTWW8vnId1TXf78OqqUkwcdpsJv3+MQb26ET55ire/+RL3v/kS/oc3T7A6WvnO38/fNK4dU9Va2Drdt5yoM2PfaGZjTez5Wa2fNO333oc8ae74/bbOO+CXxOriGEWIhQKUV5eHvRYaSEajW67vOz1pbRs2QqAx2c/wsmnjGTz5s1YKISFQsRisaDGTBuJhGPS7x9j6Hm3Eo87mu/fiBN6Hcrc2y6gY+umXHXBidvue+rgI3n8hbdokJdLIuFIJBwFDcIBTl87LzuDgWogCygFmqauOxRYBawDBqau6wZ88mMLcM7NIrllxOFHdE2LvVzV1dUUnTiEd1etZNjQQVxz3Q2079CBjZ9/TqfOnYlEIpw59nSaNGnCJZdNDnrctPDaq0u44dqryQ3n0v24HnQ7+hji8TglLy3g9ln3UFpayphTR+Cc46HH5gQ9buCa7bsXf7nhLBKJBA8+8wZ/fXoZ05gHwMJ7fss1tz8DQChk9O3egfFX/ZWGBXnM/uN4zIziiXcFOX6tzLmf/2c4td/laeAroIDkls0mYKFz7l4z+xPQBYgCpzvn/lnX8g4/oqt7/c303tmVbqLRKFXVcQoKC4MeJaOUR6McUTSF8spE0KNknIK8EN++fusK51y37W/zskXjnCsFjq/j9gt9fF8RSU/puYtaRPYoCo2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeKfQiIh3Co2IeJcd9AD1kYjHiUajQY+RUWLl5VTVxIMeI+PEYuVEwoYeg3deJGx8W8tt5pzbrcPsCjP7BtgQ9By12AdqXb9SO623XZPu6+0Xzrl9t78yI0KTzsxsuXOuW9BzZBqtt12TqetN24ci4p1CIyLeKTQ/3aygB8hQWm+7JiPXm/bRiIh32qIREe8UGhHxTqEREe8Uml1kZhPM7JgffG5BzpMpzOwSMzs06DkyiZll/M+pdgb/BGZWAAx3zj0U9CyZwszCqYujnXP3BTpMBjAzc845M2sG/BewFPjGOfdiwKPtlIwv5e62dcsl9Q9/H/A7M5uy/e3yf219VHbOVQE5wJlmNiHYqdJfKjINgduB9UAFMMLMBgc62E5SaHbSD/7h7wTud851BTpujY3TJuK/SD0qJ8ysuZldCvQBRgC9zOzXAY+XlrZ7wDJgC3Cvc24O8CzQNJDBdpFCs2uyge+AtwCcc6cBg83s3ECnSkNmFkrFuRB4GvgQmJj6uAA42cwuCHLGdJRaZw3MbCywP/AUMMrMGgMHAL80s1CmbEFnxGki0lApsAo4ysxqgDYk31G7KNCp0lBqS6YhcDBwM7AQmAA8B0SBE4AmwU2YnlIB
2YtkVHoCZUAWcBdQCPzWOZcIbsKdo53Bu8jMfgEUA51I/hD92jm3Jtip0scPdmJmAYcDt5E8xUEVcBLJR+kbgSLn3HfBTZo+zKw1sAmIA2cCDwO5QBGQIBnpUiDknPs6qDl3hULzE5hZDtCI5D/8V0HPk25SR+UeAn4LtAfGAy2BPwFnA2c75z4MbsL0kQrykSQDMxZ4HniP5HubjO+3Bu/OxP2ACo14kfrBaQLcBGwm+YMzFPgbyR+gGufc34ObMP2Y2V7AIyR3/M4A9gOOAT4AhgATnHNfBjfhrtM+GvnZmVk2cDHJHb8bgHySh2a7AnnOuTODmy59OedKzewc4FBgAN/vxxoKTM7UyIBCIx4452rM7DGgA8n9MocBr5J8dD4oyNnSnXPuCzMrJbkjeALJp+bjnXOfBTvZT6OnTuKVmbUEbgE+BS52zlUHPFJGSD317ADEnHPrAx7nJ1NoxLvUq6jj2mH+70uhERHv9MpgEfFOoRER7xQaEfFOoRER7xQa2Slm1tvMNphZiZnNNbO8nfzaq1OXb63lPi3MrO/OLk/Sm0Iju+IB51xvkmd7K4bkmyh35pQFzrmLarmpBVCv0Ejm0CuD5ad4B5hrZv2A5sAZZnY+yVAkgHHOufVmdg/Jd7hvAP4HwMxecc71NLMewHSS7++5k+Q7lXuY2bHOuX6pE4rtcHmS3rRFIz/F8cAaYK1zbiDJc6c0T23tXAhMNrOjSb5Yrz/wyY8sY+upIvoAj5F8t/IDqcgcugvLkzSkLRrZFWNTWyIfkDzz24rU9R2B3mZWkvr8S6AV8Hbq8xXAsdsvzDn3berPxHbPvnZpeZJ+FBrZFQ84564ASO2M3XqmtzXAC1v3v6TO19OV5DmCAY74kWU5M2vinNuUOoF5Nckzye3q8iQN6amT/GyccyuBv6eOSC0ieWKrZUDYzBYC7X7kyyYDT6fuP5LkuWp6mNmju7g8SUN6r5OIeKctGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe8UGhHxTqEREe/+FyFB47oov2c4AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "#!/usr/bin/env python\n",
    "# _*_ coding:utf-8 _*_\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_Matrix(cm, classes, title=None, cmap=plt.cm.Blues):\n",
    "    \"\"\"Plot a row-normalized confusion matrix as an annotated heat map.\n",
    "\n",
    "    Every row of ``cm`` is divided by its row sum, the normalized matrix is\n",
    "    printed tab-separated, and each cell that rounds to at least 1% is\n",
    "    labelled with its whole-number percentage.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    cm : array-like, 2-D\n",
    "        Confusion-matrix counts. Rows go on the 'Actual' axis, columns on\n",
    "        the 'Predicted' axis. The caller's array is not modified.\n",
    "    classes : sequence of str\n",
    "        Tick labels, used for both axes.\n",
    "    title : str, optional\n",
    "        Axes title.\n",
    "    cmap : matplotlib colormap, optional\n",
    "        Colormap passed to ``imshow``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "        The figure is displayed with ``plt.show()``.\n",
    "    \"\"\"\n",
    "    # Font family and size. Note: plt.rc changes matplotlib's global rcParams.\n",
    "    plt.rc('font', family='Times New Roman', size='8')\n",
    "\n",
    "    # Row-wise normalization. Rows whose sum is 0 stay all-zero instead of\n",
    "    # turning into NaN, which would make the integer conversion below fail.\n",
    "    counts = np.asarray(cm, dtype=float)\n",
    "    row_totals = counts.sum(axis=1, keepdims=True)\n",
    "    cm = np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals != 0)\n",
    "    print(\"Normalized confusion matrix\")\n",
    "    # Builtin str is used because the np.str alias was removed in NumPy 1.24.\n",
    "    for row in cm.astype(str).tolist():\n",
    "        print('\\t'.join(row))\n",
    "\n",
    "    # Whole-number percentage of every cell (adding 0.5 then truncating rounds half up).\n",
    "    percent = (cm * 100 + 0.5).astype(int)\n",
    "    # Cells below 1% are set to 0 so they do not show up in the color map.\n",
    "    cm[percent == 0] = 0\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    im = ax.imshow(cm, interpolation='nearest', cmap=cmap)\n",
    "    # ax.figure.colorbar(im, ax=ax)  # optional color bar at the side\n",
    "\n",
    "    ax.set(xticks=np.arange(cm.shape[1]),\n",
    "           yticks=np.arange(cm.shape[0]),\n",
    "           xticklabels=classes, yticklabels=classes,\n",
    "           title=title,\n",
    "           ylabel='Actual',\n",
    "           xlabel='Predicted')\n",
    "\n",
    "    # Minor ticks on the cell boundaries plus a minor grid emulate cell borders.\n",
    "    ax.set_xticks(np.arange(cm.shape[1] + 1) - .5, minor=True)\n",
    "    ax.set_yticks(np.arange(cm.shape[0] + 1) - .5, minor=True)\n",
    "    ax.grid(which=\"minor\", color=\"gray\", linestyle='-', linewidth=0.2)\n",
    "    ax.tick_params(which=\"minor\", bottom=False, left=False)\n",
    "\n",
    "    # Rotate the x tick labels by 45 degrees.\n",
    "    plt.setp(ax.get_xticklabels(), rotation=45, ha=\"right\",\n",
    "             rotation_mode=\"anchor\")\n",
    "\n",
    "    # Annotate every non-zero cell with its percentage; text is white on cells\n",
    "    # above half of the maximum value and black otherwise.\n",
    "    thresh = cm.max() / 2.\n",
    "    for i, j in zip(*np.nonzero(percent > 0)):\n",
    "        ax.text(j, i, f\"{percent[i, j]}%\",\n",
    "                ha=\"center\", va=\"center\",\n",
    "                color=\"white\" if cm[i, j] > thresh else \"black\")\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# Draw the confusion matrix `cms` (computed in an earlier cell).\n",
    "# NOTE(review): the tick labels are assumed to follow the integer label order\n",
    "# (ci, wei, tuo) - confirm against the label mapping used for training.\n",
    "plot_Matrix(cms,classes=['ci', 'wei', 'tuo'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 生成结果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2000 \n",
      " 拖网    1254\n",
      "围网     477\n",
      "刺网     269\n",
      "Name: predict, dtype: int64\n",
      "        ID predict\n",
      "8942  7000      围网\n",
      "8256  7001      拖网\n",
      "8027  7002      围网\n",
      "7458  7003      拖网\n",
      "7124  7004      围网\n"
     ]
    }
   ],
   "source": [
    "# Majority vote: for every entry of cv_pred keep the most frequent label.\n",
    "# np.argmax returns the first maximum, so ties resolve to the lowest label.\n",
    "# NOTE(review): presumably each entry holds the per-fold predictions of one\n",
    "# test sample as non-negative ints - confirm where cv_pred is built.\n",
    "submit = [np.argmax(np.bincount(line)) for line in cv_pred]\n",
    "\n",
    "# Prediction frame. The explicit copy avoids assigning a column to a slice of\n",
    "# `test` (the SettingWithCopyWarning pattern).\n",
    "res = test[['ID']].copy()\n",
    "res['predict'] = submit\n",
    "res['predict'] = res['predict'].map(label2type)\n",
    "\n",
    "print(len(res), '\\n',res.predict.value_counts())\n",
    "print(res.sort_values('ID').head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the submission file (predictions, not a model): rows sorted by ID,\n",
    "# without index or header.\n",
    "os.makedirs('../output', exist_ok=True)  # to_csv does not create missing directories\n",
    "res.sort_values('ID').to_csv('../output/demo1_submission.csv', index=False, header=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3 (fastai)",
   "language": "python",
   "name": "python3_fastai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
